import os
import pathlib
import platform
import shutil
import time
import warnings

try:
    from importlib.resources import files
except ImportError:
    from importlib_resources import files  # type: ignore[no-redef]

from typing import Union, Optional, Callable, Text, Tuple

import casadi as cs
import torch

try:
    from torch.func import jacrev, jacfwd, functionalize, vjp
except ImportError:
    from functorch import jacrev, jacfwd, functionalize, vjp
from l4casadi.ts_compiler import ts_compile
from torch.fx.experimental.proxy_tensor import make_fx

from l4casadi.template_generation import render_casadi_c_template
from l4casadi.naive import NaiveL4CasADiModule


def dynamic_lib_file_starting() -> str:
    """Return the platform-specific file-name prefix for dynamic libraries.

    'lib' on Linux/macOS (e.g. libfoo.so / libfoo.dylib) and '' on Windows
    (e.g. foo.dll). Unknown platforms now default to the POSIX 'lib'
    convention instead of silently returning None (the previous behavior),
    which would have produced broken file names downstream.
    """
    return '' if platform.system() == 'Windows' else 'lib'

def dynamic_lib_file_ending() -> str:
    """Return the platform-specific dynamic library file extension.

    :return: '.dylib' on macOS, '.so' on Linux, '.dll' on Windows.
    :raises RuntimeError: On unsupported platforms. Previously this function
        silently returned None, which later produced nonsensical library
        file names such as 'libfooNone'.
    """
    system = platform.system()
    endings = {'Darwin': '.dylib', 'Linux': '.so', 'Windows': '.dll'}
    try:
        return endings[system]
    except KeyError:
        raise RuntimeError(f'L4CasADi does not support the platform {system!r}.') from None


class L4CasADi(object):
    """Wraps a PyTorch model as a CasADi external function.

    The model is traced/exported to TorchScript, a C++ shim is generated from a
    template and compiled into a dynamic library, which CasADi then loads via
    `cs.external`. First- and second-order sensitivities are exported as
    separate traced models where possible.
    """

    def __init__(self,
                 model: Callable[[torch.Tensor], torch.Tensor],
                 batched: bool = False,
                 device: Union[torch.device, Text] = 'cpu',
                 name: Text = 'l4casadi_f',
                 build_dir: Text = './_l4c_generated',
                 model_search_path: Optional[Text] = None,
                 generate_jac: bool = True,
                 generate_adj1: bool = True,
                 generate_jac_adj1: bool = True,
                 generate_jac_jac: bool = False,
                 scripting: bool = True,
                 mutable: bool = False):
        """
        :param model: PyTorch model.
        :param batched: If True, the first dimension of the two expected input dimension is assumed to be a batch
            dimension. This can lead to speedups as sensitivities across this dimension can be neglected.
        :param device: Device on which the PyTorch model is executed.
        :param name: Unique name of the generated L4CasADi model. This name is used for autogenerated files.
            Creating two L4CasADi models with the same name will result in overwriting the files of the first model.
        :param build_dir: Directory where the L4CasADi library is built.
        :param model_search_path: Path to the directory where the PyTorch model can be found. By default, this will be
            the absolute path to the `build_dir` where the model traces are exported to. This parameter can become
            useful if the created L4CasADi dynamic library and the exported PyTorch Models are expected to be moved to a
            different folder (or another device).
        :param generate_jac: If True, the Jacobian of the model is tried to be generated.
        :param generate_adj1: If True, the Adjoint of the model is tried to be generated.
        :param generate_jac_adj1: If True, the Jacobian of the Adjoint of the model is tried to be generated.
        :param generate_jac_jac: If True, the Hessian of the model is tried to be generated.
        :param scripting: If True, the model is traced using TorchScript. If False, the model is compiled.
        :param mutable: If True, enables updating the model online via the update method.
        """
        if not scripting:
            warnings.warn("L4CasADi with Torch AOT compilation is experimental at this point and might not work as "
                          "expected.")
            raise RuntimeError("PyTorch compile is not supported yet as it does not seem stable.")
            # NOTE(review): The lines below are unreachable due to the unconditional raise above.
            # They are kept as a placeholder for when AOT compilation support is re-enabled.
            if torch.__version__ < torch.torch_version.TorchVersion('2.4.0'):
                raise RuntimeError("For PyTorch versions < 2.4.0 L4CasADi only supports jit scripting. Please pass "
                                   "scripting=True.")
            import torch._inductor.config as config
            config.freezing = True

        self.model = model
        self.naive = False
        if isinstance(self.model, NaiveL4CasADiModule):
            self.naive = True
        if isinstance(self.model, torch.nn.Module):
            # Freeze the model: inference mode only, no gradient tracking on parameters.
            self.model.eval().to(device)
            for parameters in self.model.parameters():
                parameters.requires_grad = False
        self.name = name
        self.batched = batched
        self.device = device if isinstance(device, str) else f'{device.type}:{device.index}'

        self.build_dir = pathlib.Path(build_dir)

        self._model_search_path = model_search_path

        self._cs_fun: Optional[cs.Function] = None
        self._built = False

        self._generate_jac = generate_jac
        self._generate_adj1 = generate_adj1
        self._generate_jac_adj1 = generate_jac_adj1
        self._generate_jac_jac = generate_jac_jac

        self._scripting = scripting

        if mutable and platform.system() == "Windows":
            raise RuntimeError('Online model update (mutable=True) is not supported on Windows.')
        self._mutable = mutable

        # Shapes are inferred lazily in `generate`; (-1, -1) marks "unknown yet".
        self._input_shape: Tuple[int, int] = (-1, -1)
        self._output_shape: Tuple[int, int] = (-1, -1)

    def update(self, model: Optional[Callable[[torch.Tensor], torch.Tensor]] = None) -> None:
        """
        Updates the PyTorch model online.
        :param model: Optional, new model. If None, old reference to model given to initializer will be used for update.
        """

        if not self._mutable:
            raise RuntimeError('To update the model online, please initialize L4CasADi with `mutable=True`.')

        if not self._built:
            raise RuntimeError('L4CasADi has to be built first before the model can be updated.')

        if model is not None:
            self.model = model
            if isinstance(self.model, torch.nn.Module):
                self.model.eval().to(self.device)
                for parameters in self.model.parameters():
                    parameters.requires_grad = False

        self.export_torch_traces()  # type: ignore[misc]

        # Give the filesystem a moment to settle before signaling the reload.
        time.sleep(0.2)

        # Create file on disk to signal C++ to reload the model
        open(self.build_dir / f'{self.name}.reload', 'w').close()

        # Wait until C++ has reloaded the model (it deletes the signal file when done).
        while os.path.exists(self.build_dir / f'{self.name}.reload'):
            time.sleep(0.01)

    def __call__(self, *args):
        return self.forward(*args)

    @property
    def shared_lib_dir(self):
        # Absolute POSIX path to the directory containing the compiled library.
        return self.build_dir.absolute().as_posix()

    def forward(self, inp: Union[cs.MX, cs.SX, cs.DM]):
        """Evaluate the model on a symbolic (or numeric) CasADi input.

        Builds the dynamic library and loads it as a CasADi external function
        on first use (lazy build).
        """
        if self.naive:
            out = self.model(inp)
        else:
            if not self._built:
                self.build(inp)
            if self._cs_fun is None:
                self._load_built_library_as_external_cs_fun()

            out = self._cs_fun(inp)  # type: ignore[misc]

        return out

    def maybe_make_generation_dir(self):
        # Create the build directory if it does not exist yet.
        if not os.path.exists(self.build_dir):
            os.makedirs(self.build_dir)

    def build(self, inp: Union[cs.MX, cs.SX, cs.DM]) -> None:
        """Builds the L4CasADi model as dynamic library.

        1. Exports the traced PyTorch model to TorchScript.
        2. Fills the C++ template with model parameters and paths to TorchScript.
        3. Compiles the C++ template to a dynamic library.
        4. Loads the dynamic library as CasADi external function.

        :param inp: Symbolic model input. Used to infer expected input shapes.
        """

        self.maybe_make_generation_dir()

        # TODO: The naive case could potentially be removed. Not sure if there exists a use-case for this.
        if self.naive:
            # Naive modules are pure CasADi expressions; generate C directly from them.
            rows, cols = inp.shape  # type: ignore[attr-defined]
            inp_sym = cs.MX.sym('inp', rows, cols)
            out_sym = self.model(inp_sym)
            cs.Function(f'{self.name}', [inp_sym], [out_sym]).generate(f'{self.name}.cpp')
            shutil.move(f'{self.name}.cpp', (self.build_dir / f'{self.name}.cpp').as_posix())
        else:
            self.generate(inp)

        self.compile()

        self._built = True

    def _verify_input_output(self):
        # The generated C++ shim only supports matrix-shaped (2D) outputs.
        if len(self._output_shape) != 2:
            raise ValueError(f"""L4CasADi requires the model output to be a matrix (2 dimensions) but has 
                              {len(self._output_shape)} dimensions. Please add a extra dimension of size 1. 
                              For models which expects a batch dimension, the output should be a matrix of [1, d].""")

        if self.batched:
            if self._input_shape[0] != self._output_shape[0]:
                raise ValueError(f"""When the model is batched the first dimension of input and output (batch dimension)
                                    has to be the same.""")

    def generate(self, inp: Union[cs.MX, cs.SX, cs.DM]) -> None:
        """Infer shapes, export traces and render the C++ template."""
        self._input_shape = inp.shape  # type: ignore[attr-defined]
        # Run the model once on zeros to discover the output shape.
        self._output_shape = self.model(torch.zeros(*self._input_shape).to(self.device)).shape
        self._verify_input_output()

        has_jac, has_adj1, has_jac_adj1, has_jac_jac = self.export_torch_traces()
        if not has_jac and self._generate_jac:
            warnings.warn('Jacobian trace could not be generated.'
                          ' First-order sensitivities will not be available in CasADi.')
        if not has_adj1 and self._generate_adj1:
            warnings.warn('Adjoint trace could not be generated.'
                          ' First-order sensitivities will not be available in CasADi.')
        if not has_jac_adj1 and self._generate_jac_adj1:
            warnings.warn('Jacobian Adjoint trace could not be generated.'
                          ' Second-order sensitivities will not be available in CasADi.')
        if not has_jac_jac and self._generate_jac_jac:
            warnings.warn('Hessian trace could not be generated.'
                          ' Second-order sensitivities will not be available in CasADi.')
        self._generate_cpp_function_template(has_jac, has_adj1, has_jac_adj1, has_jac_jac)

    def _load_built_library_as_external_cs_fun(self):
        # Load the compiled shared library as a CasADi external function.
        if not self._built:
            raise RuntimeError('L4CasADi model has not been built yet. Call `build` first.')
        self._cs_fun = cs.external(
            f'{self.name}',
            f"{self.build_dir / f'{dynamic_lib_file_starting()}{self.name}'}{dynamic_lib_file_ending()}"
        )

    @staticmethod
    def generate_block_diagonal_ccs(batch_size, input_size, output_size):
        """Generate compressed-column-storage (CCS) sparsity patterns for the
        block-diagonal Jacobian and Hessian of a batched model.

        https://de.wikipedia.org/wiki/Harwell-Boeing-Format
        :param batch_size: Size of batch dimension.
        :param input_size: Size of input vector.
        :param output_size: Size of output vector.
        :return:
            jac_ccs, hess_ccs
        """
        # Jacobian dimensions [batch_size * output_size, batch_size * input_size]
        col_ptr = list(range(0, batch_size * input_size * output_size, output_size)) + [
            batch_size * input_size * output_size]
        row_ind = []
        for _ in range(input_size):
            for batch_idx in range(batch_size):
                row_ind += list(range(batch_idx, batch_idx + batch_size * output_size, batch_size))

        # CasADi CCS layout: [n_rows, n_cols, col_ptr..., row_ind...]
        jac_ccs = [batch_size * output_size, batch_size * input_size] + col_ptr + row_ind

        # Hessian dimensions [batch_size * output_size * batch_size * input_size, batch_size * input_size]
        col_ptr = list(range(0, batch_size * input_size * output_size * input_size, input_size * output_size)) + [
            batch_size * input_size * output_size * input_size]
        row_ind = []
        for _ in range(input_size):
            for batch_idx in range(batch_size):
                for jacobian_idx in range(0, batch_size * output_size * batch_size * input_size,
                                          output_size * batch_size * batch_size):
                    row_ind += list(range(jacobian_idx + batch_idx * batch_size * output_size + batch_idx,
                                          (jacobian_idx + batch_idx * batch_size * output_size
                                           + batch_idx + batch_size * output_size),
                                          batch_size))

        hess_ccs = [batch_size * output_size * batch_size * input_size, batch_size * input_size] + col_ptr + row_ind

        return jac_ccs, hess_ccs

    def _generate_cpp_function_template(self, has_jac: bool, has_adj1: bool, has_jac_adj1: bool, has_jac_jac: bool):
        """Render the C++ CasADi-function template with all model parameters."""

        model_path = (self.build_dir.absolute().as_posix()
                      if self._model_search_path is None
                      else self._model_search_path)

        if self.batched:
            # For batched models only the block-diagonal entries of the
            # sensitivities are non-zero; export their sparsity patterns.
            jac_ccs, jac_jac_ccs = self.generate_block_diagonal_ccs(self._input_shape[0],
                                                                    self._input_shape[1],
                                                                    self._output_shape[1])
            jac_adj_css, _ = self.generate_block_diagonal_ccs(self._input_shape[0],
                                                              self._input_shape[1],
                                                              self._input_shape[1])
        else:
            jac_ccs, jac_adj_css, jac_jac_ccs = None, None, None

        gen_params = {
            'model_path': model_path,
            'device': self.device,
            'name': self.name,
            'rows_in': self._input_shape[0],
            'cols_in': self._input_shape[1],
            'rows_out': self._output_shape[0],
            'cols_out': self._output_shape[1],
            'has_jac': 'true' if has_jac else 'false',
            'has_adj1': 'true' if has_adj1 else 'false',
            'has_jac_adj1': 'true' if has_jac_adj1 else 'false',
            'has_jac_jac': 'true' if has_jac_jac else 'false',
            'scripting': 'true' if self._scripting else 'false',
            'model_is_mutable': 'true' if self._mutable else 'false',
            'batched': 'true' if self.batched else 'false',
            'jac_ccs_len': len(jac_ccs) if self.batched else 0,
            'jac_ccs': ', '.join(str(e) for e in jac_ccs) if self.batched else '',
            'jac_adj_ccs_len': len(jac_adj_css) if self.batched else 0,
            'jac_adj_ccs': ', '.join(str(e) for e in jac_adj_css) if self.batched else '',
            'jac_jac_ccs_len': len(jac_jac_ccs) if self.batched else 0,
            'jac_jac_ccs': ', '.join(str(e) for e in jac_jac_ccs) if self.batched else '',
        }

        render_casadi_c_template(
            variables=gen_params,
            out_file=(self.build_dir / f'{self.name}.cpp').as_posix()
        )

    def compile(self):
        """Compile the generated C++ shim into a dynamic library.

        Prefers direct gcc compilation on Linux/macOS; falls back to cmake
        (required on Windows, or when gcc is unavailable).
        """
        file_dir = files('l4casadi')
        include_dir = files('l4casadi') / 'include'
        lib_dir = file_dir / 'lib'

        # If gcc is available on a POSIX system, use it to compile the dynamic library directly.
        if platform.system() != "Windows" and shutil.which("gcc"):
            soname = 'install_name' if platform.system() == 'Darwin' else 'soname'
            # Must match the ABI libtorch was built with, otherwise linking fails at runtime.
            cxx11_abi = 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0
            link_libl4casadi = " -ll4casadi" if not self.naive else ""
            os_cmd = ("gcc"
                    " -fPIC -shared"
                    f" {self.build_dir / self.name}.cpp"
                    f" -o {self.build_dir / f'lib{self.name}'}{dynamic_lib_file_ending()}"
                    f" -I{include_dir} -L{lib_dir}"
                    f" -Wl,-{soname},{dynamic_lib_file_starting()}{self.name}{dynamic_lib_file_ending()}"
                    f"{link_libl4casadi}"
                    " -lstdc++ -std=c++17"
                    f" -D_GLIBCXX_USE_CXX11_ABI={cxx11_abi}")

        elif shutil.which("cmake"):
            # Otherwise fall back to cmake (the only supported path on Windows).
            # get current working dir as posix
            cwd = pathlib.Path('.').absolute()

            linked_lib = f"target_link_libraries({self.name} l4casadi)" if not self.naive else ""
            # BUG FIX: `platform.system` was previously compared without being called,
            # so the _GLIBCXX_USE_CXX11_ABI define was also emitted on Windows/MSVC.
            glibcxx_use_cxx11_abi = (
                ""
                if platform.system() == "Windows" else
                f"add_definitions(-D_GLIBCXX_USE_CXX11_ABI={1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0})"
            )

            cmake_content = f"""
                cmake_minimum_required(VERSION 3.15)
                project({self.name})
                set(CMAKE_CXX_STANDARD 17)
                set(CMAKE_POSITION_INDEPENDENT_CODE ON)
                
                include_directories({include_dir.as_posix()})
                link_directories({lib_dir.as_posix()})

                add_library({self.name} SHARED {self.name}.cpp)
                {linked_lib}
                
                set_target_properties({self.name} PROPERTIES
                    LIBRARY_OUTPUT_DIRECTORY {(cwd / self.build_dir).as_posix()}
                    RUNTIME_OUTPUT_DIRECTORY {(cwd / self.build_dir).as_posix()}
                )
                {glibcxx_use_cxx11_abi}

                install(TARGETS {self.name} DESTINATION {(cwd / self.build_dir).as_posix()})
            """
            with open(self.build_dir / "CMakeLists.txt", "w") as f:
                f.write(cmake_content)

            os_cmd = (
                f"cmake -S {self.build_dir} -B {self.build_dir} -DCMAKE_RULE_MESSAGES=OFF && "
                f"cmake --build {self.build_dir} --config=Release"
            )
            # On Windows the build output has to be installed into the build dir explicitly.
            os_cmd += f" && cmake --install {self.build_dir} --config=Release" if platform.system() == "Windows" else ""
        else:
            raise RuntimeError("Please install gcc (Linux and Mac) or cmake (Windows, Linux and Mac) to compile the dynamic library.")

        status = os.system(os_cmd)
        if status != 0:
            raise Exception(f'Compilation failed!\n\nAttempted to execute OS command:\n{os_cmd}\n\n')

    def _trace_jac_model(self, inp):
        """Trace the Jacobian of the model via reverse-mode autodiff."""
        if self.batched:
            def with_batch_dim(x):
                # vmap over the batch dimension; permute to CasADi's expected axis order.
                return torch.func.vmap(jacrev(self.model))(x[:, None])[:, 0].permute(1, 0, 2, 3)

            return make_fx(functionalize(with_batch_dim, remove='mutations_and_views'))(inp)
        return make_fx(functionalize(jacrev(self.model), remove='mutations_and_views'))(inp)

    def _trace_adj1_model(self):
        """Trace the adjoint (vector-Jacobian product) of the model."""
        p_d = torch.zeros(self._input_shape).to(self.device)
        t_d = torch.zeros(self._output_shape).to(self.device)

        def _vjp(p, x):
            return vjp(self.model, p)[1](x)[0]

        return make_fx(functionalize(_vjp, remove='mutations_and_views'))(p_d, t_d)

    def _trace_jac_adj1_p_model(self):
        """Trace the Jacobian of the adjoint w.r.t. the model input (argnums=0)."""
        p_d = torch.zeros(self._input_shape).to(self.device)
        t_d = torch.zeros(self._output_shape).to(self.device)

        def _vjp(p, x):
            return vjp(self.model, p)[1](x)[0]

        # TODO: replace jacfwd with jacrev depending on answer in https://github.com/pytorch/pytorch/issues/130735
        if self.batched:
            def with_batch_dim(p, x):
                return torch.func.vmap(jacfwd(_vjp, argnums=0))(p[:, None], x[:, None])[:, 0].permute(3, 2, 0, 1)

            return make_fx(functionalize(with_batch_dim, remove='mutations_and_views'))(p_d, t_d)
        return make_fx(functionalize(jacfwd(_vjp, argnums=0), remove='mutations_and_views'))(p_d, t_d)

    def _trace_jac_adj1_t_model(self):
        """Trace the Jacobian of the adjoint w.r.t. the adjoint seed (argnums=1)."""
        p_d = torch.zeros(self._input_shape).to(self.device)
        t_d = torch.zeros(self._output_shape).to(self.device)

        def _vjp(p, x):
            return vjp(self.model, p)[1](x)[0]

        # TODO: replace jacfwd with jacrev depending on answer in https://github.com/pytorch/pytorch/issues/130735
        if self.batched:
            def with_batch_dim(p, x):
                return torch.func.vmap(jacfwd(_vjp, argnums=1))(p[:, None], x[:, None])[:, 0].permute(3, 2, 0, 1)

            return make_fx(functionalize(with_batch_dim, remove='mutations_and_views'))(p_d, t_d)
        return make_fx(functionalize(jacfwd(_vjp, argnums=1), remove='mutations_and_views'))(p_d, t_d)

    def _trace_hess_model(self, inp):
        """Trace the Hessian (Jacobian of the Jacobian) of the model."""
        if self.batched:
            def with_batch_dim(x):
                # Permutation is trial and error
                return torch.func.vmap(jacrev(jacrev(self.model)))(x[:, None])[:, 0].permute(1, 3, 2, 0, 4, 5)

            return make_fx(functionalize(with_batch_dim, remove='mutations_and_views'))(inp)
        return make_fx(functionalize(jacrev(jacrev(self.model)), remove='mutations_and_views'))(inp)

    def export_torch_traces(self) -> Tuple[bool, bool, bool, bool]:
        """Export the forward model trace and all requested sensitivity traces.

        :return: Tuple of flags (jac, adj1, jac_adj1, hessian) indicating which
            sensitivity traces could be exported successfully.
        """
        d_inp = torch.zeros(self._input_shape)
        d_inp = d_inp.to(self.device)

        d_out = torch.zeros(self._output_shape)
        d_out = d_out.to(self.device)

        out_folder = self.build_dir

        # Forward trace. NOTE(review): its success flag is intentionally not
        # checked here — a failing forward export surfaces at library load time.
        self.model_compile(make_fx(functionalize(self.model, remove='mutations_and_views'))(d_inp),
                                   (out_folder / f'{self.name}.pt').as_posix(),
                                   (d_inp,))

        exported_jac = False
        if self._generate_jac:
            jac_model = self._trace_jac_model(d_inp)

            exported_jac = self.model_compile(
                jac_model,
                (out_folder / f'jac_{self.name}.pt').as_posix(),
                (d_inp,)
            )

        exported_adj1 = False
        if self._generate_adj1:
            adj1_model = self._trace_adj1_model()
            exported_adj1 = self.model_compile(
                adj1_model,
                (out_folder / f'adj1_{self.name}.pt').as_posix(),
                (d_inp, d_out)
            )

        exported_jac_adj1 = False
        if self._generate_jac_adj1:
            jac_adj1_p_model = self._trace_jac_adj1_p_model()
            exported_jac_adj1_p = self.model_compile(
                jac_adj1_p_model,
                (out_folder / f'jac_adj1_p_{self.name}.pt').as_posix(),
                (d_inp, d_out)
            )

            jac_adj1_t_model = self._trace_jac_adj1_t_model()
            exported_jac_adj1_t = self.model_compile(
                jac_adj1_t_model,
                (out_folder / f'jac_adj1_t_{self.name}.pt').as_posix(),
                (d_inp, d_out)
            )

            # Both partial derivatives are required for second-order sensitivities.
            exported_jac_adj1 = exported_jac_adj1_p and exported_jac_adj1_t

        exported_hess = False
        if self._generate_jac_jac:
            hess_model = None
            try:
                hess_model = self._trace_hess_model(d_inp)
            except:  # noqa
                # Hessian tracing is best-effort; callers are warned via `generate`.
                pass

            if hess_model is not None:
                exported_hess = self.model_compile(
                    hess_model,
                    (out_folder / f'jac_jac_{self.name}.pt').as_posix(),
                    (d_inp,)
                )

        return exported_jac, exported_adj1, exported_jac_adj1, exported_hess

    def model_compile(self, model, file_path: str, dummy_inp: Tuple[torch.Tensor, ...]):
        """Export a traced model either via TorchScript or AOT compilation."""
        if self._scripting:
            return self._jit_compile_and_save(model, file_path, dummy_inp)
        else:
            return self._aot_compile_and_save(model, file_path, dummy_inp)

    @staticmethod
    def _aot_compile_and_save(model, file_path: str, dummy_inp: Tuple[torch.Tensor, ...]):
        """AOT-compile the model to a shared object. Returns True on success."""
        try:
            with torch.no_grad():
                torch._export.aot_compile(
                    model,
                    dummy_inp,
                    # Swap the '.pt' suffix for '.so' as AOT inductor emits a shared object.
                    options={"aot_inductor.output_path": file_path[:-2] + 'so'},
                )
            return True
        except:  # noqa
            return False

    @staticmethod
    def _jit_compile_and_save(model, file_path: str, dummy_inp: Tuple[torch.Tensor, ...]):
        """Save the model as TorchScript, preferring scripting over tracing.

        Returns True on success, False if both scripting and tracing fail.
        """
        try:
            # Try scripting
            ts_compile(model).save(file_path)
        except:  # noqa
            # Try tracing
            try:
                torch.jit.trace(model, dummy_inp).save(file_path)
            except:  # noqa
                return False
        return True
